// Copyright 2022 The Google Research Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#ifndef THIRD_PARTY_DNN_ACCELERATOR_HLS_SYSTEMC_TENSORCORE_H_
#define THIRD_PARTY_DNN_ACCELERATOR_HLS_SYSTEMC_TENSORCORE_H_

#include <mc_connections.h>
#include <systemc.h>

#include "src/AccelTypes.h"
#include "src/DoubleBuffer.h"
#include "src/InputController.h"
#include "src/MatrixProcessor.h"
#include "src/WeightController.h"
#include "src/VectorUnit.h"

/// Top-level tensor-core module.
///
/// Instantiates and wires together:
///   * an InputController driving an input DoubleBuffer,
///   * a WeightController driving a weight DoubleBuffer,
///   * a MatrixProcessor fed by the outputs of both double buffers,
///   * a VectorUnit consuming the MatrixProcessor outputs.
///
/// The `run` thread broadcasts every `Params` popped from `paramsIn` to the
/// input controller, the weight controller and the matrix processor.
/// `vectorParamsIn` is bound directly to the vector unit.
///
/// Template parameters:
///   IDTYPE                   element type of input vectors.
///   WDTYPE                   element type of weight vectors.
///   ODTYPE                   element type used by the matrix-processor output
///                            path and the vector unit.
///   NROWS                    width of input vectors (Pack1D<IDTYPE, NROWS>).
///   NCOLS                    width of weight vectors (Pack1D<WDTYPE, NCOLS>).
///   INP_BUF_SIZE             size parameter of the input DoubleBuffer.
///   WEIGHT_BUF_SIZE          size parameter of the weight DoubleBuffer.
///   ACCUMULATION_BUFFER_SIZE forwarded to MatrixProcessor.
template <typename IDTYPE, typename WDTYPE, typename ODTYPE, int NROWS,
          int NCOLS, int INP_BUF_SIZE, int WEIGHT_BUF_SIZE,
          int ACCUMULATION_BUFFER_SIZE>
SC_MODULE(TensorCore) {
  sc_in<bool> CCS_INIT_S1(clk);
  // Active-low reset: `run` registers it with async_reset_signal_is(rstn, false).
  sc_in<bool> CCS_INIT_S1(rstn);

  // Parameters broadcast by `run` to the input controller, weight controller
  // and matrix processor.
  Connections::In<Params> CCS_INIT_S1(paramsIn);
  // Parameters bound straight through to the vector unit.
  Connections::In<VectorParams> CCS_INIT_S1(vectorParamsIn);

  // ---------------------------------------------------------------- Input path
  InputController<IDTYPE, NROWS> CCS_INIT_S1(inputController);
  DoubleBuffer<IDTYPE, NROWS, INP_BUF_SIZE> CCS_INIT_S1(inputDoubleBuffer);
  Connections::Out<MemoryRequest> CCS_INIT_S1(inputAddressRequest);
  Connections::In<Pack1D<IDTYPE, NROWS> > CCS_INIT_S1(inputDataResponse);
  // Two channel sets (index 0 and 1) between the controller and the double
  // buffer -- presumably one per buffer bank; confirm in DoubleBuffer.h.
  Connections::Combinational<int> inputBufferWriteAddress[2];
  Connections::Combinational<Pack1D<IDTYPE, NROWS> > inputBufferWriteData[2];
  Connections::Combinational<int> inputBufferWriteControl[2];
  Connections::Combinational<int> inputBufferReadAddress[2];
  Connections::Combinational<int> inputBufferReadControl[2];
  Connections::Combinational<Params> CCS_INIT_S1(inputControllerParams);

  // --------------------------------------------------------------- Weight path
  WeightController<WDTYPE, NROWS, NCOLS> CCS_INIT_S1(weightController);
  DoubleBuffer<WDTYPE, NCOLS, WEIGHT_BUF_SIZE> CCS_INIT_S1(weightDoubleBuffer);
  Connections::Out<MemoryRequest> CCS_INIT_S1(weightAddressRequest);
  // Weight data must use the weight element type (WDTYPE), matching the
  // WeightController and the weight buffer write-data channels below.
  Connections::In<Pack1D<WDTYPE, NCOLS> > CCS_INIT_S1(weightDataResponse);
  // Two channel sets (index 0 and 1); see the note on the input path.
  Connections::Combinational<int> weightBufferWriteAddress[2];
  Connections::Combinational<Pack1D<WDTYPE, NCOLS> > weightBufferWriteData[2];
  Connections::Combinational<int> weightBufferWriteControl[2];
  Connections::Combinational<int> weightBufferReadAddress[2];
  Connections::Combinational<int> weightBufferReadControl[2];
  Connections::Combinational<Params> CCS_INIT_S1(weightControllerParams);

  // ---------------------------------------------------------- Matrix processor
  MatrixProcessor<IDTYPE, WDTYPE, ODTYPE, NROWS, NCOLS,
                  ACCUMULATION_BUFFER_SIZE>
      CCS_INIT_S1(matrixProcessor);
  Connections::Combinational<Pack1D<IDTYPE, NROWS> > CCS_INIT_S1(
      inputsToSystolicArray);
  Connections::Combinational<Pack1D<WDTYPE, NCOLS> > CCS_INIT_S1(
      weightsToSystolicArray);
  // Matrix-processor results flow into VectorUnit<ODTYPE, ...>, so they carry
  // the output element type (ODTYPE), not the weight type.
  Connections::Combinational<Pack1D<ODTYPE, NCOLS> > CCS_INIT_S1(
      outputsFromSystolicArray);
  Connections::Combinational<Params> CCS_INIT_S1(matrixProcessorParams);

  // --------------------------------------------------------------- Vector unit
  VectorUnit<ODTYPE, NCOLS, NROWS> CCS_INIT_S1(vectorUnit);
  Connections::Out<int> CCS_INIT_S1(vectorFetchAddressRequest);
  Connections::In<Pack1D<ODTYPE, NROWS> > CCS_INIT_S1(vectorFetchDataResponse);
  Connections::Out<Pack1D<ODTYPE, NROWS> > CCS_INIT_S1(vectorUnitOutput);
  Connections::Out<int> CCS_INIT_S1(outputAddress);
  Connections::SyncOut CCS_INIT_S1(done);

  SC_CTOR(TensorCore) {
    // Input controller <-> external memory and params fan-out.
    inputController.clk(clk);
    inputController.rstn(rstn);
    inputController.addressRequest(inputAddressRequest);
    inputController.dataResponse(inputDataResponse);
    inputController.paramsIn(inputControllerParams);

    inputDoubleBuffer.clk(clk);
    inputDoubleBuffer.rstn(rstn);

    // Each controller port and the matching double-buffer port share one
    // combinational channel.
    for (int i = 0; i < 2; i++) {
      inputController.bufferWriteAddress[i](inputBufferWriteAddress[i]);
      inputController.bufferWriteData[i](inputBufferWriteData[i]);
      inputController.bufferWriteControl[i](inputBufferWriteControl[i]);
      inputController.bufferReadAddress[i](inputBufferReadAddress[i]);
      inputController.bufferReadControl[i](inputBufferReadControl[i]);

      inputDoubleBuffer.writeAddress[i](inputBufferWriteAddress[i]);
      inputDoubleBuffer.writeData[i](inputBufferWriteData[i]);
      inputDoubleBuffer.writeControl[i](inputBufferWriteControl[i]);
      inputDoubleBuffer.readAddress[i](inputBufferReadAddress[i]);
      inputDoubleBuffer.readControl[i](inputBufferReadControl[i]);
    }

    // Weight controller <-> external memory and params fan-out.
    weightController.clk(clk);
    weightController.rstn(rstn);
    weightController.addressRequest(weightAddressRequest);
    weightController.dataResponse(weightDataResponse);
    weightController.paramsIn(weightControllerParams);

    weightDoubleBuffer.clk(clk);
    weightDoubleBuffer.rstn(rstn);

    for (int i = 0; i < 2; i++) {
      weightController.bufferWriteAddress[i](weightBufferWriteAddress[i]);
      weightController.bufferWriteData[i](weightBufferWriteData[i]);
      weightController.bufferWriteControl[i](weightBufferWriteControl[i]);
      weightController.bufferReadAddress[i](weightBufferReadAddress[i]);
      weightController.bufferReadControl[i](weightBufferReadControl[i]);

      weightDoubleBuffer.writeAddress[i](weightBufferWriteAddress[i]);
      weightDoubleBuffer.writeData[i](weightBufferWriteData[i]);
      weightDoubleBuffer.writeControl[i](weightBufferWriteControl[i]);
      weightDoubleBuffer.readAddress[i](weightBufferReadAddress[i]);
      weightDoubleBuffer.readControl[i](weightBufferReadControl[i]);
    }

    // Double-buffer outputs feed the matrix processor.
    inputDoubleBuffer.output(inputsToSystolicArray);
    weightDoubleBuffer.output(weightsToSystolicArray);

    matrixProcessor.clk(clk);
    matrixProcessor.rstn(rstn);

    matrixProcessor.inputsChannel(inputsToSystolicArray);
    matrixProcessor.weightsChannel(weightsToSystolicArray);
    matrixProcessor.outputsChannel(outputsFromSystolicArray);
    matrixProcessor.paramsIn(matrixProcessorParams);

    // Vector unit: consumes matrix-processor outputs and talks to the outside
    // world directly through this module's ports.
    vectorUnit.clk(clk);
    vectorUnit.rstn(rstn);
    vectorUnit.paramsIn(vectorParamsIn);
    vectorUnit.systolicArrayOutput(outputsFromSystolicArray);
    vectorUnit.vectorFetchAddressRequest(vectorFetchAddressRequest);
    vectorUnit.vectorFetchDataResponse(vectorFetchDataResponse);
    vectorUnit.vectorUnitOutput(vectorUnitOutput);
    vectorUnit.outputAddress(outputAddress);
    vectorUnit.done(done);

    SC_THREAD(run);
    sensitive << clk.pos();
    async_reset_signal_is(rstn, false);
  }

  /// Params broadcast thread.
  ///
  /// On reset, resets the `paramsIn` reader and the three params writers.
  /// Afterwards it loops forever: it pops one `Params` and pushes the same
  /// value, in order, to the input controller, the weight controller and the
  /// matrix processor. Each Push is issued in sequence, so the next Pop only
  /// happens after all three pushes have been issued.
  void run() {
    paramsIn.Reset();
    inputControllerParams.ResetWrite();
    weightControllerParams.ResetWrite();
    matrixProcessorParams.ResetWrite();

    wait();

    while (true) {
      Params params = paramsIn.Pop();
      inputControllerParams.Push(params);
      weightControllerParams.Push(params);
      matrixProcessorParams.Push(params);
    }
  }
};

#endif  // THIRD_PARTY_DNN_ACCELERATOR_HLS_SYSTEMC_TENSORCORE_H_
